import numpy as np
import pandas as pd
from keras import callbacks
from keras.models import load_model
import keras
import keras.optimizers as opt
from keras import Input, layers
from keras.models import Model
import matplotlib.pyplot as plt
import keras.backend as K
import os
from find_higher_1000 import citys_higher,city_list
# Per-city integer values (presumably the number of areas in each city --
# confirm); not referenced in this script.
city_map = dict(A=118, B=30, C=135, D=75, E=34, F=331, G=38, H=53, I=33, J=8, K=48)

# Suffixes of the pretrained "toy_v5_<suffix>.h5" model files, one per bucket
# of the series peak; the bucket bounds (200, 350, 520, 850, 1300, 2600, 5500)
# match the if/elif chain that picks the model in the main loop.
string_list = [
    "less_than_200",
    "200_350",
    "350_520",
    "520_850",
    "850_1300",
    "1300_2600",
    "2600_5500",
    "greater_equal_5500",
]

# Reference snapshot of per-city area ids (presumably the value of citys_higher
# at some point -- verify before relying on it):
#{'A': [4, 9, 10, 11, 13, 14, 16, 18, 19, 20, 26, 29, 30, 31, 32, 35, 37, 38, 39, 40, 41, 42, 43, 44, 47, 49, 50, 51, 52, 53, 54, 56, 57, 61, 62, 64, 65, 66, 67, 68, 75, 77, 79, 80, 93, 99, 104, 107, 108, 110], 'B': [16, 20], 'C': [11, 18, 27, 28, 32, 33, 34, 36, 38, 39, 45, 49, 53, 57, 58, 63, 66, 69, 70, 76, 77, 89, 90, 103, 108, 109, 112, 118, 132], 'D': [11, 13, 20, 39], 'E': [7], 'F': [18, 21, 22, 53, 70, 85, 95, 105, 107, 108, 109, 112, 120, 121, 122, 123, 124, 128, 131, 132, 133, 134, 135, 136, 137, 143, 145, 146, 147, 148, 149, 162, 163, 164, 165,177, 178, 179, 185, 186, 192, 195, 196, 211, 212, 230, 235, 242, 246, 252, 256, 261, 267, 271, 277, 278, 283, 284, 292, 296, 303, 316, 319], 'G': [11, 25], 'H': [4,  21, 47, 48, 49], 'I': [26], 'J': [2], 'K': [5, 9, 10, 23, 26, 30, 33, 39, 40, 46]}

# Area ids to process for each city, provided by find_higher_1000.
city_map_list = citys_higher

# Callbacks shared by every ensemble_model.fit call; both watch training loss.
callback_list = [
    callbacks.EarlyStopping(monitor="loss", patience=40),
    callbacks.ReduceLROnPlateau(
        monitor="loss",
        factor=0.8,
        verbose=1,
        patience=12,
    ),
]

# (upper bound on the series peak, numerator constant) pairs, checked in
# order; peaks >= 5500 (or NaN) fall through to _REDUCE_TOP_CONSTANT.  The
# bounds mirror the peak ranges the main loop uses to pick the pretrained model.
_REDUCE_BUCKETS = (
    (200, 1600),
    (350, 3700),
    (520, 7000),
    (850, 10000),
    (1300, 15000),
    (2600, 27000),
    (5500, 50000),
)
_REDUCE_TOP_CONSTANT = 180000


def reduce(x, peak=None):
    """Rescale ``x`` by a factor derived from the peak of the target series.

    Used as the function of the ``Lambda`` layer in ``model__3``.  The result
    is ``x / (c / (1.05 * peak))`` where ``c`` is the constant of the first
    bucket in ``_REDUCE_BUCKETS`` whose upper bound exceeds ``peak`` (or
    ``_REDUCE_TOP_CONSTANT`` when none does).

    Args:
        x: Value or tensor to rescale; anything supporting ``/``.
        peak: Peak value of the series.  Defaults to ``max(target)`` of the
            module-level ``target`` global, which the training loop assigns
            before building the model; the Lambda layer relies on this default.

    Returns:
        The rescaled ``x``.
    """
    if peak is None:
        # Evaluate the peak once instead of once per comparison.
        peak = max(target)
    rate = 1.05
    constant = _REDUCE_TOP_CONSTANT
    for upper_bound, bucket_constant in _REDUCE_BUCKETS:
        if peak < upper_bound:
            constant = bucket_constant
            break
    return x / (constant / (rate * peak))


def draw(train_data, target, predict_em, chs):
    """Plot a model curve against the observed series and save the figure.

    Args:
        train_data: x-values (day numbers) of the observed series.
        target: Observed values, plotted against ``train_data``.
        predict_em: Model curve; must hold 30 values (plotted at days
            61-90) when ``chs == 0`` and 90 values (days 1-90) otherwise.
        chs: Selects the x-range for ``predict_em`` as described above.

    The figure is saved under ``F:\\predict_adjust_<city>``, which is created
    if missing.  The directory and file name come from the module-level
    globals ``city``, ``i`` and ``j`` set by the main loop, not from the
    arguments.
    """
    days = np.arange(61, 91) if chs == 0 else np.arange(1, 91)
    plt.figure()
    plt.plot(days, predict_em, label="train_model")
    plt.plot(train_data, target, label="data")
    plt.legend(loc="best")
    # Raw strings: "\p" and "\{" are invalid escape sequences in ordinary
    # literals (DeprecationWarning, SyntaxWarning since Python 3.12); the
    # resulting path text is unchanged.  exist_ok=True replaces the racy
    # exists()/makedirs() pair.
    os.makedirs(r"F:\predict_adjust_{}".format(city), exist_ok=True)
    plt.savefig(r"F:\predict_adjust_{}\{}城{}区{}拟合".format(city, city, str(i), str(j)))
    plt.close("all")

def save_data(result, i, city, first=None):
    """Write one area's predictions to ``predict_adjust_<city>.csv``.

    Rows are indexed by the area id ``i``.  The first area of a city
    (re)creates the file with a header row; every later area is appended
    without repeating the header.

    Args:
        result: 2-D array of ``(day, predicted infected count)`` rows.
        i: Area id, used as the index value of every written row.
        city: City key; selects the output file name.
        first: Whether ``i`` is the first area written for ``city``.  When
            ``None`` (the default) this is decided, as before, by whether
            ``i`` is the first entry of the module-level
            ``city_map_list[city]``.
    """
    columns = ["天数", "感染人数"]
    path = "predict_adjust_{}.csv".format(city)
    if first is None:
        first = city_map_list[city].index(i) == 0
    frame = pd.DataFrame(result, index=[i] * result.shape[0], columns=columns)
    if first:
        frame.to_csv(path, columns=columns)
    else:
        # Append instead of read/concat/rewrite.  Re-reading with
        # read_csv(names=...) implies header=None, so the header line came
        # back as a data row and was duplicated into the data on every call.
        # A header is written only if the file does not exist yet (e.g. a
        # resumed run whose first area was skipped).
        frame.to_csv(path, columns=columns, mode="a",
                     header=not os.path.exists(path))

def sort(df):
    """Sort records by area and date, report area-id coverage, and save them.

    Prints the sorted frame, then prints whether the distinct area ids are
    exactly ``0, 1, ..., max``.  The sorted rows are written without header or
    index to ``infection_<city>.csv``, where ``city`` is the module-level
    global set by the main loop.

    Args:
        df: DataFrame containing at least the ``区域`` (area) and ``日期``
            (date) columns.

    Returns:
        The sorted DataFrame.  Callers that ignore the return value are
        unaffected.
    """
    df = df.sort_values(by=["区域", "日期"])
    print(df)
    # Sort the distinct ids explicitly: set iteration order is not guaranteed,
    # so list(set(...)) must not be compared against an ordered list.
    area_ids = sorted(set(df["区域"]))
    print(area_ids == list(range(0, max(df["区域"]) + 1)))
    df.to_csv("infection_{}.csv".format(city), header=False, index=False)
    return df


def RMSLE(y_true, y_pred):
    """Root mean squared logarithmic error built from Keras backend ops.

    Both tensors are clipped from below at ``K.epsilon()`` before
    ``log(value + 1)`` is taken, so non-positive inputs stay finite.

    NOTE(review): not referenced in this script; presumably kept for use as
    a custom loss/metric.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.

    Returns:
        Scalar tensor ``sqrt(mean((log(p + 1) - log(t + 1)) ** 2))``.
    """
    floor = K.epsilon()

    def log1p_of_clipped(tensor):
        return K.log(K.clip(tensor, floor, None) + 1.)

    gap = log1p_of_clipped(y_pred) - log1p_of_clipped(y_true)
    return K.sqrt(K.mean(K.square(gap)))


# Main driver: for every city (starting at "F") and each of its selected areas,
# forecast days 61-90 from the observed days, then plot (draw) and save
# (save_data) the forecast.
# NOTE(review): this loop sets the module-level globals ``city``, ``i``, ``j``
# and ``target``. ``draw`` reads ``city``/``i``/``j`` to build its file name and
# ``reduce`` (the Lambda layer) reads ``target``, so renaming these variables or
# changing their final values alters those functions' behavior.
for city in city_list[city_list.index("F"):]:  # cities before "F" are skipped
    # Columns: city, area id, date, daily increase; only area id and increase are kept.
    df = pd.read_csv("infection_{}.csv".format(city), names=["城市", "区域", "日期", "增加人数"])
    df = df.drop(columns=["城市", "日期"])
    for i in city_map_list[city] if city!="F" else city_map_list[city][city_map_list[city].index(235):]:# modified: for city "F", resume from area 235 (earlier areas presumably already processed)
        area = df[df["区域"] == i]
        # reset_index exposes the original CSV row number as column "index";
        # % 60 + 1 maps it to a day number 1..60.
        # NOTE(review): assumes each area occupies 60 consecutive rows starting
        # at a multiple of 60 in the file -- confirm against the data.
        area = area.reset_index()
        area["index"] = (area["index"]) % 60 + 1
        # Columns are now: day number, area id, daily increase.
        area.columns = ["天数", "区域", "增加人数"]
        train_data = area["天数"]
        train_data = np.array(train_data)
        target = area["增加人数"]
        target = np.array(target)

        # If the last observed value (index 59 -- assumes exactly 60 rows) is
        # already below 3.8% of the peak, skip the model and extrapolate a
        # straight line from target[59] down to 0 at day 90.
        if(target[59]<0.038*max(target)):
            des_part = (target[59] - 0) / 30
            result = list()
            for j in range(30):
                if (j == 0):
                    result.append(target[59] - des_part)
                else:
                    result.append(result[j - 1] - des_part)
            x_test = np.arange(61, 91, 1)
            predict_em = np.array(result)
            predict_em = np.where(predict_em >= 0, predict_em, 0)
            # Rows of (day, predicted value) for days 61..90.
            result = np.concatenate((x_test.reshape(30, 1), predict_em.reshape(30, 1)), axis=1)
            # NOTE: j is 29 here (end of range(30)); draw uses it in the file name.
            draw(train_data,target,predict_em,0)
            save_data(result,i,city)

        else:
            best_model = None
            best_model_loss = None

            # Pick the frozen pretrained model whose peak bucket matches this
            # series (same bounds as in reduce()).
            if (max(target) < 200):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[0]))
            elif (max(target) >= 200 and max(target) < 350):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[1]))
            elif (max(target) >= 350 and max(target) < 520):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[2]))
            elif (max(target) >= 520 and max(target) < 850):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[3]))
            elif (max(target) >= 850 and max(target) < 1300):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[4]))
            elif (max(target) >= 1300 and max(target) < 2600):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[5]))
            elif (max(target) >= 2600 and max(target) < 5500):
                model__1 = load_model("toy_v5_{}.h5".format(string_list[6]))
            else:
                model__1 = load_model("toy_v5_{}.h5".format(string_list[7]))

            model__1.trainable = False
            # NOTE(review): assigning Model.name works in older Keras; newer
            # versions expose it read-only -- confirm the pinned Keras version.
            model__1.name = "model_1"

            # Up to 15 rounds; each round trains 6 fresh candidates and keeps the
            # one with the lowest training MSE. The loop stops early (break
            # below) once that round's best forecast passes the shape checks.
            # NOTE(review): up to 90 models are built per area without clearing
            # the Keras session; memory may grow over a long run.
            for count in range(1,16):
                for j in range(1,7):
                    # model__2: small trainable MLP mapping day number -> input of model__1.
                    data_input = Input(shape=(1,))
                    x=layers.BatchNormalization()(data_input)
                    x=layers.Dense(32,activation="relu")(x)
                    x = layers.Dense(16, activation="relu")(x)
                    y = layers.Dense(16, activation="relu")(x)

                    y = layers.Dense(1)(y)
                    y=layers.BatchNormalization()(y)
                    model__2 = Model(inputs=data_input, outputs=y)
                    model__2.name = "model_2"

                    # model__3: rescales model__1's output via reduce() (which
                    # reads the global ``target``), then two linear Dense layers.
                    data_input = Input(shape=(1,))
                    z = layers.Lambda(reduce)(data_input)
                    z = layers.Dense(64)(z)
                    predict_3 = layers.Dense(1)(z)
                    model__3 = Model(data_input, predict_3)
                    model__3.name = "model_3"

                    # Chain: day -> model__2 -> frozen model__1 -> model__3.
                    ensemble_input = keras.Input(shape=(1,))
                    ensemble_output = model__3(model__1(model__2(ensemble_input)))
                    ensemble_model = Model(ensemble_input, ensemble_output)

                    # NOTE(review): opt.adam() is the lowercase alias from older
                    # Keras releases -- confirm it exists in the pinned version.
                    ensemble_model.compile(optimizer=opt.adam(), loss="mse")
                    ensemble_model.fit(train_data, target, epochs=3000, batch_size=60, callbacks=callback_list)
                    y = ensemble_model.predict(train_data)
                    y = y.reshape((60,))  # assumes exactly 60 training rows
                    print(((y - target) ** 2).mean())
                    # Progress markers: current round and candidate number.
                    print([count] * 50)
                    print([j] * 50)
                    # best_model is reset at j == 1, so it holds the best
                    # candidate of the current round only.
                    if(j==1):
                        best_model=ensemble_model
                        best_model_loss=((y - target) ** 2).mean()
                    else:
                        if(best_model_loss>((y - target) ** 2).mean()):
                            best_model=ensemble_model
                            best_model_loss=((y - target) ** 2).mean()



                # Forecast days 61..90 with this round's best model; negatives clipped to 0.
                ensemble_model=best_model
                x_test = np.arange(61, 91, 1)
                y_test = ensemble_model.predict(x_test)
                y_test = np.where(y_test >= 0, y_test, 0)
                # True when at most 2 of the 29 day-to-day steps fail to descend
                # (a step counts as failing if the value rises, or stays equal
                # while above 10% of the peak). Its ``i`` is local to the function.
                def judge_descent(y_test):
                    count_des =0
                    for i in range(29):
                        if(y_test[i][0]<y_test[i+1][0] or (y_test[i][0]==y_test[i+1][0] and y_test[i][0]>0.1*max(target))):
                            count_des=count_des+1
                    if(count_des<=2):
                        return True
                    return False

                # NOTE(review): defined but never called.
                def judge_tail(y_test,target):
                    if(y_test[0][0]<=1.8*target[59] and y_test[0][0]>=0.2*target[59]):
                        return True
                    return False



                # Acceptance checks on the forecast (y_test[11] is day 72,
                # y_test[29] is day 90). A pass ends the rounds early.
                if (max(target) > 1000):
                    # Large peak: day 90 <= 6% of peak and <= last observed,
                    # mostly descending, day 72 <= 20% of last observed.
                    if (y_test[29][0] <= 0.06 * max(target) and judge_descent(y_test) and target[59] >= y_test[29][0] and y_test[11][0]<=0.2*target[59]):
                        break
                elif (target[59] < 0.35 * max(target)):
                    # Already well past the peak: day 90 <= last observed,
                    # mostly descending, day 90 < 15% of peak.
                    if (target[59] >= y_test[29][0] and judge_descent(y_test) and y_test[29][0] < 0.15 * max(target)):
                        break
                else:
                    if(max(target)<100):
                        # Small peak: day 72 must stay above 15% of last observed.
                        if (target[59] >= y_test[29][0] and judge_descent(y_test) and y_test[11][0]>0.15*target[59] and y_test[29][0] < 0.3 * max(target)):
                            break
                    else:
                        # Otherwise: day 72 >= 5% of peak, day 90 < 30% of peak.
                        if (target[59] >= y_test[29][0] and judge_descent(y_test) and y_test[11][0]>=0.05*max(target) and y_test[29][0] < 0.3 * max(target)):
                            break

            # Use the last evaluated round's best model (whether or not it passed).
            ensemble_model = best_model
            x_test = np.arange(61, 91, 1)
            y_test = ensemble_model.predict(x_test)
            y_test = np.where(y_test >= 0, y_test, 0)
            # Rows of (day, predicted value) for days 61..90.
            result = np.concatenate((x_test.reshape(30, 1), y_test.reshape(30, 1)), axis=1)
            # Full curve over days 1..90 for the plot.
            predict_em = ensemble_model.predict(np.arange(1, 91)).reshape(90, )
            predict_em = np.where(predict_em >= 0, predict_em, 0)
            # NOTE: j is 6 here (end of range(1,7)); draw uses it in the file name.
            draw(train_data, target, predict_em, 1)
            save_data(result, i,city)